{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJcYs_ERTnnI"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HMUDt0CiUJk9"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77z2OchJTk0l"
      },
      "source": [
        "# Migrate SessionRunHook to Keras callbacks\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/sessionrunhook_callback\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/sessionrunhook_callback.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/sessionrunhook_callback.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/sessionrunhook_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KZHPY55aFyXT"
      },
      "source": [
        "In TensorFlow 1, to customize the behavior of training, you use `tf.estimator.SessionRunHook` with `tf.estimator.Estimator`. This guide demonstrates how to migrate from `SessionRunHook` to TensorFlow 2's custom callbacks with the `tf.keras.callbacks.Callback` API, which works with Keras `Model.fit` for training (as well as `Model.evaluate` and `Model.predict`). You will learn how to do this by implementing a `SessionRunHook` and a `Callback` task that measures examples per second during training.\n",
        "\n",
        "Examples of callbacks are checkpoint saving (`tf.keras.callbacks.ModelCheckpoint`) and [TensorBoard](`tf.keras.callbacks.TensorBoard`) summary writing. Keras [callbacks](https://www.tensorflow.org/guide/keras/custom_callback) are objects that are called at different points during training/evaluation/prediction in the built-in Keras `Model.fit`/`Model.evaluate`/`Model.predict` APIs. You can learn more about callbacks in the `tf.keras.callbacks.Callback` API docs, as well as the [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback.ipynb/) and [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate) (the *Using callbacks* section) guides."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "29da56bf859d"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Start with imports and a simple dataset for demonstration purposes:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "296d8b0DoKpV"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow.compat.v1 as tf1\n",
        "\n",
        "import time\n",
        "from datetime import datetime\n",
        "from absl import flags"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xVGYtUXyXNuE"
      },
      "outputs": [],
      "source": [
        "features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n",
        "labels = [[0.3], [0.5], [0.7]]\n",
        "eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]\n",
        "eval_labels = [[0.8], [0.9], [1.]]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ON4zQifT0Vec"
      },
      "source": [
        "## TensorFlow 1: Create a custom SessionRunHook with tf.estimator APIs\n",
        "\n",
        "The following TensorFlow 1 examples show how to set up a custom `SessionRunHook` that measures examples per second during training. After creating the hook (`LoggerHook`), pass it to the `hooks` parameter of `tf.estimator.Estimator.train`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S-myEclbXUL7"
      },
      "outputs": [],
      "source": [
        "def _input_fn():\n",
        "  return tf1.data.Dataset.from_tensor_slices(\n",
        "      (features, labels)).batch(1).repeat(100)\n",
        "\n",
        "def _model_fn(features, labels, mode):\n",
        "  logits = tf1.layers.Dense(1)(features)\n",
        "  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)\n",
        "  optimizer = tf1.train.AdagradOptimizer(0.05)\n",
        "  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n",
        "  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xd9sPTkO0ZTD"
      },
      "outputs": [],
      "source": [
        "class LoggerHook(tf1.train.SessionRunHook):\n",
        "  \"\"\"Logs loss and runtime.\"\"\"\n",
        "\n",
        "  def begin(self):\n",
        "    self._step = -1\n",
        "    self._start_time = time.time()\n",
        "    self.log_frequency = 10\n",
        "\n",
        "  def before_run(self, run_context):\n",
        "    self._step += 1\n",
        "\n",
        "  def after_run(self, run_context, run_values):\n",
        "    if self._step % self.log_frequency == 0:\n",
        "      current_time = time.time()\n",
        "      duration = current_time - self._start_time\n",
        "      self._start_time = current_time\n",
        "      examples_per_sec = self.log_frequency / duration\n",
        "      print('Time:', datetime.now(), ', Step #:', self._step,\n",
        "            ', Examples per second:', examples_per_sec)\n",
        "\n",
        "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n",
        "\n",
        "# Begin training.\n",
        "estimator.train(_input_fn, hooks=[LoggerHook()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3uZCDMrM2CEg"
      },
      "source": [
        "## TensorFlow 2: Create a custom Keras callback for Model.fit\n",
        "\n",
        "In TensorFlow 2, when you use the built-in Keras `Model.fit` (or `Model.evaluate`) for training/evaluation, you can configure a custom `tf.keras.callbacks.Callback`, which you then pass to the `callbacks` parameter of `Model.fit` (or `Model.evaluate`). (Learn more in the [Writing your own callbacks](../..guide/keras/custom_callback.ipynb) guide.)\n",
        "\n",
        "In the example below, you will write a custom `tf.keras.callbacks.Callback` that logs various metrics—it will measure examples per second, which should be comparable to the metrics in the previous `SessionRunHook` example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UbMPoiB92KRG"
      },
      "outputs": [],
      "source": [
        "class CustomCallback(tf.keras.callbacks.Callback):\n",
        "\n",
        "    def on_train_begin(self, logs = None):\n",
        "      self._step = -1\n",
        "      self._start_time = time.time()\n",
        "      self.log_frequency = 10\n",
        "\n",
        "    def on_train_batch_begin(self, batch, logs = None):\n",
        "      self._step += 1\n",
        "\n",
        "    def on_train_batch_end(self, batch, logs = None):\n",
        "      if self._step % self.log_frequency == 0:\n",
        "        current_time = time.time()\n",
        "        duration = current_time - self._start_time\n",
        "        self._start_time = current_time\n",
        "        examples_per_sec = self.log_frequency / duration\n",
        "        print('Time:', datetime.now(), ', Step #:', self._step,\n",
        "              ', Examples per second:', examples_per_sec)\n",
        "\n",
        "callback = CustomCallback()\n",
        "\n",
        "dataset = tf.data.Dataset.from_tensor_slices(\n",
        "    (features, labels)).batch(1).repeat(100)\n",
        "\n",
        "model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n",
        "optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n",
        "\n",
        "model.compile(optimizer, \"mse\")\n",
        "\n",
        "# Begin training.\n",
        "result = model.fit(dataset, callbacks=[callback], verbose = 0)\n",
        "# Provide the results of training metrics.\n",
        "result.history"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EFqFi21Ftskq"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "Learn more about callbacks in:\n",
        "\n",
        "- API docs: `tf.keras.callbacks.Callback`\n",
        "- Guide: [Writing your own callbacks](../..guide/keras/custom_callback.ipynb/)\n",
        "- Guide: [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate) (the *Using callbacks* section)\n",
        "\n",
        "You may also find the following migration-related resources useful:\n",
        "\n",
        "- The [Early stopping migration guide](early_stopping.ipynb): `tf.keras.callbacks.EarlyStopping` is a built-in early stopping callback\n",
        "- The [TensorBoard migration guide](tensorboard.ipynb): TensorBoard enables tracking and displaying metrics\n",
        "- The [LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide](logging_stop_hook.ipynb)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "sessionrunhook_callback.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
